
Conversation


@qiruiyangmeta qiruiyangmeta commented Oct 1, 2025

Add utility functions to enable load-balanced token sharding for context parallelism.

Purpose

Causal attention imposes a varying computational load per token, as shown in the figure below: later tokens attend over longer prefixes and are therefore more expensive. To distribute the workload evenly, tokens are partitioned across context parallelism (CP) ranks: the sequence is split into 2 × cp_world_size chunks, and each CP rank i is assigned both the i-th chunk and the (2 × cp_world_size - i - 1)-th chunk, pairing a cheap chunk with an expensive one. This balances the compute load across all CP ranks.
[Figure: per-token compute load under causal attention]
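
A minimal sketch of this chunk assignment (hypothetical helper; the actual utilities live in vllm/v1/attention/backends/cp_utils.py and may differ in details):

    def assign_chunks(seq_len: int, cp_world_size: int) -> list[list[int]]:
        """Split token indices into 2 * cp_world_size chunks and give rank i
        both chunk i and chunk (2 * cp_world_size - i - 1)."""
        num_chunks = 2 * cp_world_size
        chunk_size = (seq_len + num_chunks - 1) // num_chunks
        chunks = [list(range(start, min(start + chunk_size, seq_len)))
                  for start in range(0, seq_len, chunk_size)]
        chunks += [[] for _ in range(num_chunks - len(chunks))]  # pad short sequences
        return [chunks[rank] + chunks[num_chunks - rank - 1]
                for rank in range(cp_world_size)]

    # With 16 tokens and cp_world_size = 2:
    #   rank 0 -> tokens [0, 1, 2, 3] + [12, 13, 14, 15]  (lightest + heaviest chunk)
    #   rank 1 -> tokens [4, 5, 6, 7] + [8, 9, 10, 11]
    print(assign_chunks(16, 2))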

When tokens are distributed across context parallel (CP) ranks, gaps may appear in the block table. After compaction, tokens that are physically adjacent may no longer be logically consecutive. This is acceptable for CP because we only need to preserve the correct relative order of tokens for mapping purposes, not their absolute positions in the block table.
[Figure: block-table gaps and compaction across CP ranks]
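
A rough illustration of why relative order is sufficient (hypothetical per-rank row, not the actual vLLM block-table data structure):

    def compact(row: list[int | None]) -> list[int]:
        # Drop the gaps left by positions owned by other CP ranks; the surviving
        # entries keep their original relative order, but not their absolute positions.
        return [tok for tok in row if tok is not None]

    # Rank 0 owns logical tokens 0-3 and 12-15 of a 16-token sequence; the
    # positions owned by other ranks appear as gaps (None) before compaction.
    row = [0, 1, 2, 3] + [None] * 8 + [12, 13, 14, 15]
    print(compact(row))  # [0, 1, 2, 3, 12, 13, 14, 15]; 3 and 12 are now physically adjacent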

Test Plan

pytest tests/v1/attention/test_context_parallel_attention.py

End-to-end tests will be added in follow-up PRs.

Test Result

======================================================= test session starts =======================================================
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /home/qiruiyang/vllm
configfile: pyproject.toml
plugins: anyio-4.11.0
collected 3 items                                                                                                                 

tests/v1/attention/test_context_parallel_attention.py ...                                                                   [100%]

======================================================== 3 passed in 0.51s ========================================================

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces utility functions for token sharding to enable context parallelism, which is a significant feature for improving performance. The changes correctly add the context_parallel_size configuration and update the distributed state management accordingly. The core logic for sharding is well-encapsulated in the new vllm/v1/attention/backends/cp_utils.py file. However, I've identified a couple of issues in the new test file, tests/v1/attention/test_context_parallel_attention.py, including a critical bug in an assertion that needs to be fixed for the tests to be valid.

Comment on lines +216 to +218
    assert num_comp_local == [
        num_computed_tokens[0][-1] // 2, [num_computed_tokens[1][-1] // 2]
    ]

critical

There appears to be a bug in this assertion. The expected value for num_comp_local should be a list of integers, but the expression [num_computed_tokens[1][-1] // 2] creates a list as the second element, resulting in [5, [4]]. The actual value of num_comp_local is [5, 4], which will cause this assertion to fail. For clarity and correctness, it's better to assert against the hardcoded expected value.

    assert num_comp_local == [5, 4]


    def make_cached_request_state(id: int, prefill_len: int, decode_len: int,
                                  num_computed_tokens: list[int]):
        assert prefill_len + decode_len == sum(num_computed_tokens)

high

The assertion in this helper function is incorrect. num_computed_tokens is a list of cumulative token counts, so sum(num_computed_tokens) does not represent the total number of tokens. The total number of tokens is the last element of the list. The assertion should be assert prefill_len + decode_len == num_computed_tokens[-1].

Suggested change

    -    assert prefill_len + decode_len == sum(num_computed_tokens)
    +    assert prefill_len + decode_len == num_computed_tokens[-1]
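
For concreteness, a tiny hypothetical example of the difference (values not taken from the test):

    num_computed_tokens = [2, 5, 9]         # cumulative counts after each step
    assert sum(num_computed_tokens) == 16   # not the total number of tokens
    assert num_computed_tokens[-1] == 9     # the actual total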

@qiruiyangmeta qiruiyangmeta force-pushed the prepare_inputs_for_cp branch from d07fd25 to f6b6ed3 on October 3, 2025 at 02:54

mergify bot commented Oct 7, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @qiruiyangmeta.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 7, 2025